{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "241de3c2-3361-4cc9-a52d-08d08145a112",
   "metadata": {},
   "source": [
    "# 数据构造"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9be65ac-fed3-4122-9396-1205e8fae5fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "from zhipuai import ZhipuAI\n",
    "\n",
    "# Prompt sent once per document chunk. The doubled braces are str.format() escapes\n",
    "# that produce the literal braces of the requested JSON shape.\n",
    "prompt = \"\"\"请根据后续提供的的文档片段，生成3组问题-答案对（可以根据文档片段以及问题得出答案），返回请以[{{\"question\":xx,\"answer\":xx}}]的格式。\n",
    "\n",
    "以下是文档片段：\n",
    "{}\"\"\"\n",
    "\n",
    "# SECURITY: the API key is read from the environment and never stored in the notebook.\n",
    "# A key that was ever committed in a notebook must be treated as leaked and revoked.\n",
    "token = os.environ.get('ZHIPUAI_API_KEY')\n",
    "if not token:\n",
    "    raise RuntimeError('Set the ZHIPUAI_API_KEY environment variable before running this cell.')\n",
    "\n",
    "# Single source of truth for the model name used in the request below.\n",
    "MODEL = \"glm-4-flash\"\n",
    "INPUT_PATH = 'heishenhua.txt'\n",
    "OUTPUT_PATH = 'heishenhua_qa.txt'\n",
    "\n",
    "client = ZhipuAI(api_key=token)\n",
    "\n",
    "failed = 0\n",
    "# Each input line is one document chunk. Results are appended as JSON lines, so an\n",
    "# interrupted run keeps everything that was generated so far.\n",
    "with open(INPUT_PATH, 'r', encoding='utf8') as fr, open(OUTPUT_PATH, 'a', encoding='utf8') as fw:\n",
    "    for line in fr:\n",
    "        if not line.strip():\n",
    "            continue  # do not spend an API call on a blank line\n",
    "        messages = [{\"role\": \"user\", \"content\": prompt.format(line)}]\n",
    "        try:\n",
    "            response = client.chat.completions.create(\n",
    "                model=MODEL,\n",
    "                messages=messages\n",
    "            )\n",
    "            result = response.choices[0].message.content\n",
    "        except Exception as exc:\n",
    "            # Best effort: a failed request skips this chunk, but the failure is\n",
    "            # reported and KeyboardInterrupt is no longer swallowed.\n",
    "            failed += 1\n",
    "            print('request failed, chunk skipped: {!r}'.format(exc))\n",
    "            continue\n",
    "        fw.write(json.dumps({'chunk': line, 'result': result}, ensure_ascii=False) + '\\n')\n",
    "        fw.flush()  # keep the file on disk current after every chunk\n",
    "\n",
    "print('chunks skipped because of API errors: {}'.format(failed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31dbbc5f-4503-4e85-82c4-af068cca7544",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "27e64764-625b-49cc-a7f5-1e53d075a9b3",
   "metadata": {},
   "source": [
    "# 向量训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a2dfffa-365b-442a-a269-c3fd86575db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "QA_PATH = 'heishenhua_qa.txt'\n",
    "FENCE = '```'  # markdown code-fence marker the model sometimes wraps its reply in\n",
    "\n",
    "\n",
    "def parse_qa_pairs(raw):\n",
    "    \"\"\"Parse one model reply into a list of question/answer dicts.\n",
    "\n",
    "    The whole reply is tried as JSON first, after removing an optional markdown\n",
    "    code fence. If that fails, every line starting with '[' is parsed on its own.\n",
    "    Unparseable lines and entries that are not dicts are dropped; a reply that is\n",
    "    not a string yields an empty list.\n",
    "    \"\"\"\n",
    "    if not isinstance(raw, str):\n",
    "        return []\n",
    "    text = raw.strip()\n",
    "    # Remove the fence as an exact prefix/suffix. str.lstrip/rstrip treat their\n",
    "    # argument as a set of characters and stop at surrounding whitespace.\n",
    "    if text.startswith(FENCE):\n",
    "        text = text[len(FENCE):]\n",
    "        if text.startswith('json'):\n",
    "            text = text[len('json'):]\n",
    "    if text.endswith(FENCE):\n",
    "        text = text[:-len(FENCE)]\n",
    "\n",
    "    try:\n",
    "        parsed = json.loads(text)\n",
    "    except ValueError:\n",
    "        parsed = []\n",
    "        for row in raw.split('\\n'):\n",
    "            row = row.strip()\n",
    "            if not row.startswith('['):\n",
    "                continue\n",
    "            try:\n",
    "                parsed.append(json.loads(row))\n",
    "            except ValueError:\n",
    "                # SECURITY: the reply is untrusted model output and must never be\n",
    "                # passed to eval(); literal_eval only accepts Python literals.\n",
    "                try:\n",
    "                    parsed.append(ast.literal_eval(row))\n",
    "                except (ValueError, SyntaxError, TypeError):\n",
    "                    continue\n",
    "\n",
    "    if isinstance(parsed, dict):\n",
    "        parsed = [parsed]\n",
    "    elif not isinstance(parsed, list):\n",
    "        return []\n",
    "\n",
    "    pairs = []\n",
    "    for entry in parsed:\n",
    "        # Line-by-line parsing yields one list per line; keep every dict in it\n",
    "        # rather than only the first one.\n",
    "        candidates = entry if isinstance(entry, list) else [entry]\n",
    "        pairs.extend(c for c in candidates if isinstance(c, dict))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "data = []\n",
    "skipped = 0\n",
    "with open(QA_PATH, encoding='utf8') as fr:\n",
    "    for line in fr:\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        record = json.loads(line)\n",
    "\n",
    "        # 'chunk' holds the original JSON document line with its title and content.\n",
    "        chunk = json.loads(record['chunk'])\n",
    "        context = '标题:{}\\n正文：{}'.format(chunk['title'], chunk['content'])\n",
    "\n",
    "        for pair in parse_qa_pairs(record['result']):\n",
    "            question = pair.get('question')\n",
    "            if not isinstance(question, str):\n",
    "                skipped += 1  # malformed pair: no usable question text\n",
    "                continue\n",
    "            data.append([context, question])\n",
    "\n",
    "df = pd.DataFrame(columns=['context', 'query'], data=data)\n",
    "# Keep the first context seen for each distinct query.\n",
    "df = df.drop_duplicates(subset=['query'], keep='first')\n",
    "df.to_csv('embedding_data.csv', index=False)\n",
    "print('rows written: {}, malformed pairs skipped: {}'.format(len(df), skipped))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70be864c-cdd9-4778-b74f-ee6956b60a32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bfc3eae5-d6bd-4e98-bb55-1b13472ef305",
   "metadata": {},
   "source": [
    "# SFT训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec7f2d5-450b-4c40-beea-2ae95e101d1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import random\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "QA_PATH = 'heishenhua_qa.txt'\n",
    "FENCE = '```'  # markdown code-fence marker the model sometimes wraps its reply in\n",
    "SEED = 42  # fixed seed so the generated dataset is reproducible\n",
    "PASSAGE_COUNTS = [4, 5, 6]  # number of passages sampled per example, picked at random\n",
    "REFUSAL_RATE = 0.03  # share of examples built without the gold passage\n",
    "REFUSAL_ANSWER = '抱歉，我不知道'\n",
    "MAX_CONTEXT_CHARS = 10000  # examples whose joined context reaches this length are dropped\n",
    "\n",
    "\n",
    "def parse_qa_pairs(raw):\n",
    "    \"\"\"Parse one model reply into a list of question/answer dicts.\n",
    "\n",
    "    The whole reply is tried as JSON first, after removing an optional markdown\n",
    "    code fence. If that fails, every line starting with '[' is parsed on its own.\n",
    "    Unparseable lines and entries that are not dicts are dropped; a reply that is\n",
    "    not a string yields an empty list.\n",
    "    \"\"\"\n",
    "    if not isinstance(raw, str):\n",
    "        return []\n",
    "    text = raw.strip()\n",
    "    # Remove the fence as an exact prefix/suffix. str.lstrip/rstrip treat their\n",
    "    # argument as a set of characters and stop at surrounding whitespace.\n",
    "    if text.startswith(FENCE):\n",
    "        text = text[len(FENCE):]\n",
    "        if text.startswith('json'):\n",
    "            text = text[len('json'):]\n",
    "    if text.endswith(FENCE):\n",
    "        text = text[:-len(FENCE)]\n",
    "\n",
    "    try:\n",
    "        parsed = json.loads(text)\n",
    "    except ValueError:\n",
    "        parsed = []\n",
    "        for row in raw.split('\\n'):\n",
    "            row = row.strip()\n",
    "            if not row.startswith('['):\n",
    "                continue\n",
    "            try:\n",
    "                parsed.append(json.loads(row))\n",
    "            except ValueError:\n",
    "                # SECURITY: the reply is untrusted model output and must never be\n",
    "                # passed to eval(); literal_eval only accepts Python literals.\n",
    "                try:\n",
    "                    parsed.append(ast.literal_eval(row))\n",
    "                except (ValueError, SyntaxError, TypeError):\n",
    "                    continue\n",
    "\n",
    "    if isinstance(parsed, dict):\n",
    "        parsed = [parsed]\n",
    "    elif not isinstance(parsed, list):\n",
    "        return []\n",
    "\n",
    "    pairs = []\n",
    "    for entry in parsed:\n",
    "        # Line-by-line parsing yields one list per line; keep every dict in it\n",
    "        # rather than only the first one.\n",
    "        candidates = entry if isinstance(entry, list) else [entry]\n",
    "        pairs.extend(c for c in candidates if isinstance(c, dict))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "random.seed(SEED)\n",
    "\n",
    "data = []      # [gold context, question, answer] triples\n",
    "contexts = []  # one formatted context per input record, used as the sampling pool\n",
    "\n",
    "with open(QA_PATH, encoding='utf8') as fr:\n",
    "    for line in fr:\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        record = json.loads(line)\n",
    "\n",
    "        # 'chunk' holds the original JSON document line with its title and content.\n",
    "        chunk = json.loads(record['chunk'])\n",
    "        context = '标题:{}\\n正文：{}'.format(chunk['title'], chunk['content'])\n",
    "\n",
    "        for pair in parse_qa_pairs(record['result']):\n",
    "            if 'question' not in pair or 'answer' not in pair:\n",
    "                continue  # malformed pair: cannot build a training example from it\n",
    "            data.append([context, pair['question'], pair['answer']])\n",
    "        contexts.append(context)\n",
    "\n",
    "\n",
    "values = []\n",
    "for context, question, answer in data:\n",
    "    wanted = random.choice(PASSAGE_COUNTS)\n",
    "    # random.sample raises ValueError when asked for more items than the pool holds.\n",
    "    passages = random.sample(contexts, min(wanted, len(contexts)))\n",
    "    # Drop every copy of the gold passage; the pool can contain duplicates, and a\n",
    "    # refusal example must not contain the passage that answers the question.\n",
    "    passages = [p for p in passages if p != context]\n",
    "\n",
    "    if random.random() < REFUSAL_RATE:\n",
    "        # Negative example: the gold passage is withheld, so the target is a refusal.\n",
    "        target = REFUSAL_ANSWER\n",
    "    else:\n",
    "        passages.append(context)\n",
    "        target = answer\n",
    "    random.shuffle(passages)  # the gold passage must not sit at a fixed position\n",
    "\n",
    "    joined = '\\n'.join(passages)\n",
    "    if len(joined) < MAX_CONTEXT_CHARS:\n",
    "        values.append([joined, question, target])\n",
    "\n",
    "df = pd.DataFrame(columns=['context', 'query', 'answer'], data=values)\n",
    "df.to_csv('sft_data.csv', index=False)\n",
    "print('examples written: {}'.format(len(df)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad356f91-d727-44e1-ae9d-94a9c2d6227a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0rc2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
